{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Zero-collision Hash Tutorial\n",
        "This example notebook goes through the following topics:\n",
        "- Why do we need zero-collision hash?\n",
        "- How to use the zero-collision module in TorchRec?\n",
        "\n",
        "## Prerequisites\n",
        "Before diving into the details, let's import all the necessary packages. This requires [the latest `torchrec` library to be installed](https://docs.pytorch.org/torchrec/setup-torchrec.html#installation)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1181435817001907,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "I0611 161033.883 _utils_internal.py:282] NCCL_DEBUG env var is set to None\n",
            "I0611 161033.885 _utils_internal.py:291] NCCL_DEBUG is WARN from /etc/nccl.conf\n",
            "I0611 161039.736 pyper_torch_elastic_logging_utils.py:234] initialized PyperTorchElasticEventHandler\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "from torchrec import (\n",
        "    EmbeddingCollection,\n",
        "    EmbeddingConfig,\n",
        "    JaggedTensor,\n",
        "    KeyedJaggedTensor,\n",
        "    KeyedTensor,\n",
        ")\n",
        "\n",
        "from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection\n",
        "\n",
        "from torchrec.modules.mc_modules import (\n",
        "    DistanceLFU_EvictionPolicy,\n",
        "    ManagedCollisionCollection,\n",
        "    MCHManagedCollisionModule,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Hash and Zero Collision Hash\n",
        "This section motivates two questions:\n",
        "- Why do we need to perform hash on incoming features?\n",
        "- Why do we need to implement zero-collision hash?\n",
        "\n",
        "Let's start with the first question: why do we need to hash sparse feature inputs in a recommendation model?  \n",
        "First, we create an embedding table with 1000 embeddings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# define the number of embeddings\n",
        "num_embeddings = 1000\n",
        "table_config = EmbeddingConfig(\n",
        "    name=\"t1\",\n",
        "    embedding_dim=16,\n",
        "    num_embeddings=num_embeddings,\n",
        "    feature_names=[\"f1\"],\n",
        ")\n",
        "ec = EmbeddingCollection(tables=[table_config])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Usually, we treat each input sparse feature ID as an index into the embedding table and fetch the embedding stored at that slot. However, while the size of an embedding table is fixed when the model is instantiated, the number of sparse feature values, such as video tags, can keep growing. After a while, a sparse feature ID can exceed the largest valid index of the embedding table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "feature_id = num_embeddings + 1\n",
        "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
        "    keys=[\"f1\"],\n",
        "    values=torch.tensor([feature_id]),\n",
        "    lengths=torch.tensor([1]),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "At that point, the query will lead to an `index out of range` error."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1225052112737471,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Query the embedding table of size 1000 with sparse feature ID tensor([1001])\n",
            "This query throws an IndexError: index out of range in self\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "    feature_embedding = ec(input_kjt)\n",
        "except IndexError as e:\n",
        "    print(f\"Query the embedding table of size {num_embeddings} with sparse feature ID {input_kjt['f1'].values()}\")\n",
        "    print(f\"This query throws an IndexError: {e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To avoid this error, we hash the sparse feature ID to a value within the range of the embedding table size and use the hashed value as the index to query the embedding table.\n",
        "\n",
        "For demonstration purposes, we use Python's built-in `hash` function, which leaves small non-negative integers such as the IDs used here unchanged, and remap the result to the range `[0, num_embeddings)` by taking it modulo `num_embeddings`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
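        "def remap(input_jt_value: int, num_embeddings: int) -> int:\n",
        "    # hash() leaves small non-negative ints unchanged; the modulo folds the value into [0, num_embeddings)\n",
        "    input_hash = hash(input_jt_value)\n",
        "    return input_hash % num_embeddings"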
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can query the embedding table with the remapped ID without triggering the error."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 990950286247121,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Query the embedding table of size 1000 with remapped sparse feature ID 1 from original ID 1001\n",
            "This query does not throw an IndexError, and returns the embedding of the remapped ID: {'f1': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7fe444183bf0>}\n"
          ]
        }
      ],
      "source": [
        "remapped_id = remap(feature_id, num_embeddings)\n",
        "remapped_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
        "    keys=[\"f1\"],\n",
        "    values=torch.tensor([remapped_id]),\n",
        "    lengths=torch.tensor([1]),\n",
        ")\n",
        "feature_embedding = ec(remapped_kjt)\n",
        "print(f\"Query the embedding table of size {num_embeddings} with remapped sparse feature ID {remapped_id} from original ID {feature_id}\")\n",
        "print(f\"This query does not throw an IndexError, and returns the embedding of the remapped ID: {feature_embedding}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Having answered the first question, __Why do we need to perform hash on incoming features?__, we can now turn to the second: __Why do we need to implement zero-collision hash?__\n",
        "\n",
        "Because we are mapping a larger range of values into a smaller one, some values will be remapped to the same index. For example, our `remap` function returns the same remapped ID for feature IDs `num_embeddings + 1` and `1`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1024965419837378,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "feature ID 1 is remapped to ID 1\n",
            "feature ID 1001 is remapped to ID 1\n",
            "Check if remapped feature ID 1 and 1 are the same: True\n"
          ]
        }
      ],
      "source": [
        "feature_id_1 = 1\n",
        "feature_id_2 = num_embeddings + 1\n",
        "remapped_feature_id_1 = remap(feature_id_1, num_embeddings)\n",
        "remapped_feature_id_2 = remap(feature_id_2, num_embeddings)\n",
        "print(f\"feature ID {feature_id_1} is remapped to ID {remapped_feature_id_1}\")\n",
        "print(f\"feature ID {feature_id_2} is remapped to ID {remapped_feature_id_2}\")\n",
        "print(f\"Check if remapped feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {remapped_feature_id_1 == remapped_feature_id_2}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In this case, two totally different features can share the same embedding. The situation when two feature IDs share the same remapped ID is called a **hash collision**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 923188026604331,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Embedding of feature ID 1 is tensor([ 0.0232,  0.0075,  0.0281, -0.0195, -0.0301,  0.0033,  0.0303,  0.0294,\n",
            "         0.0301, -0.0287, -0.0130, -0.0194,  0.0263,  0.0287,  0.0261, -0.0080],\n",
            "       grad_fn=<SelectBackward0>)\n",
            "Embedding of feature ID 1001 is tensor([ 0.0232,  0.0075,  0.0281, -0.0195, -0.0301,  0.0033,  0.0303,  0.0294,\n",
            "         0.0301, -0.0287, -0.0130, -0.0194,  0.0263,  0.0287,  0.0261, -0.0080],\n",
            "       grad_fn=<SelectBackward0>)\n",
            "Check if the embeddings of feature ID 1 and 1001 are the same: True\n"
          ]
        }
      ],
      "source": [
        "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
        "    keys=[\"f1\"],\n",
        "    values=torch.tensor([remapped_feature_id_1, remapped_feature_id_2]),\n",
        "    lengths=torch.tensor([1, 1]),\n",
        ")\n",
        "feature_embeddings = ec(input_kjt)\n",
        "feature_id_1_embedding = feature_embeddings[\"f1\"].values()[0]\n",
        "feature_id_2_embedding = feature_embeddings[\"f1\"].values()[1]\n",
        "# label the embeddings with the original feature IDs so the two printed lines can be told apart\n",
        "print(f\"Embedding of feature ID {feature_id_1} is {feature_id_1_embedding}\")\n",
        "print(f\"Embedding of feature ID {feature_id_2} is {feature_id_2_embedding}\")\n",
        "print(f\"Check if the embeddings of feature ID {feature_id_1} and {feature_id_2} are the same: {torch.equal(feature_id_1_embedding, feature_id_2_embedding)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "Making two different (and potentially totally unrelated) features share the same embedding leads to inaccurate recommendations.\n",
        "Luckily, for many sparse features, although the range of their IDs can be larger than the embedding table size, the IDs that actually occur are sparsely distributed over that range.\n",
        "In other cases, the embedding table may only receive frequent queries for a subset of the features.\n",
        "So we can design __managed collision hash__ modules that prevent hash collisions from happening."
      ]
    },
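    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea concrete, the next cell gives a minimal pure-Python sketch of a managed-collision remapper. It is an illustration only, not TorchRec's implementation: the class name `ToyZeroCollisionRemapper` and its simple least-frequently-used eviction rule are assumptions made up for this example. Each new feature ID claims a slot of its own; once all slots are taken, the least frequently seen ID is evicted and its slot is reused. The cell then compares this with plain modulo remapping on 1000 IDs that are spread sparsely over `[0, 10000)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "\n",
        "class ToyZeroCollisionRemapper:\n",
        "    \"\"\"Hypothetical helper (not a TorchRec API): gives every tracked feature ID its own slot.\"\"\"\n",
        "\n",
        "    def __init__(self, num_slots: int) -> None:\n",
        "        self.num_slots = num_slots\n",
        "        self.id_to_slot = {}  # feature ID -> slot index\n",
        "        self.counts = Counter()  # feature ID -> number of lookups\n",
        "\n",
        "    def remap(self, feature_id: int) -> int:\n",
        "        self.counts[feature_id] += 1\n",
        "        if feature_id in self.id_to_slot:\n",
        "            return self.id_to_slot[feature_id]\n",
        "        if len(self.id_to_slot) < self.num_slots:\n",
        "            slot = len(self.id_to_slot)  # next unused slot\n",
        "        else:\n",
        "            # all slots are taken: evict the least frequently seen tracked ID\n",
        "            victim = min(self.id_to_slot, key=lambda fid: self.counts[fid])\n",
        "            slot = self.id_to_slot.pop(victim)\n",
        "            del self.counts[victim]\n",
        "        self.id_to_slot[feature_id] = slot\n",
        "        return slot\n",
        "\n",
        "\n",
        "table_size = 1000\n",
        "sparse_ids = list(range(0, 10000, 10))  # 1000 distinct IDs spread over [0, 10000)\n",
        "\n",
        "# plain modulo remapping folds many of these IDs onto the same slot\n",
        "modulo_slots = {fid % table_size for fid in sparse_ids}\n",
        "\n",
        "# the toy remapper hands out one slot per distinct ID\n",
        "toy = ToyZeroCollisionRemapper(num_slots=table_size)\n",
        "toy_slots = {toy.remap(fid) for fid in sparse_ids}\n",
        "\n",
        "print(f\"modulo remapping uses {len(modulo_slots)} distinct slots for {len(sparse_ids)} IDs\")\n",
        "print(f\"toy zero-collision remapping uses {len(toy_slots)} distinct slots for {len(sparse_ids)} IDs\")"
      ]
    },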
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## TorchRec Zero Collision Hash Modules\n",
        "\n",
        "TorchRec implements managed collision hash strategies such as *sorted zero collision hash* and *multi-probe zero collision hash (MPZCH)*.\n",
        "\n",
        "They hash and remap feature IDs to embedding table indices with (near-)zero collisions.\n",
        "\n",
        "The MPZCH module is named `HashZchManagedCollisionModule`. The code below uses `MCHManagedCollisionModule`, the class imported at the top of this notebook, to show how to use the zero-collision modules in TorchRec.\n",
        "\n",
        "Let's assume we have two tables, `table_0` and `table_1`, which hold the embeddings for `feature_0` and `feature_1`, respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# define the table sizes\n",
        "num_embeddings_table_0 = 1000\n",
        "num_embeddings_table_1 = 2000\n",
        "\n",
        "# create table configs\n",
        "table_0_config = EmbeddingConfig(\n",
        "    name=\"table_0\",\n",
        "    embedding_dim=16,\n",
        "    num_embeddings=num_embeddings_table_0,\n",
        "    feature_names=[\"feature_0\"],\n",
        ")\n",
        "\n",
        "table_1_config = EmbeddingConfig(\n",
        "    name=\"table_1\",\n",
        "    embedding_dim=16,\n",
        "    num_embeddings=num_embeddings_table_1,\n",
        "    feature_names=[\"feature_1\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before turning the table configurations into an embedding collection, we instantiate our managed collision modules.\n",
        "\n",
        "The managed collision modules for a collection of embedding tables are organized as a dictionary of the form `{table_name: mc_module_for_the_table}`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "mc_modules = {}\n",
        "\n",
        "# Run on GPU when available, otherwise fall back to CPU\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "input_hash_size = 10000  # range of the input feature IDs\n",
        "\n",
        "# Instantiate one managed collision module per table\n",
        "mc_modules[\"table_0\"] = MCHManagedCollisionModule(\n",
        "    zch_size=table_0_config.num_embeddings,  # range of the remapped output IDs\n",
        "    input_hash_size=input_hash_size,\n",
        "    device=device,\n",
        "    eviction_interval=2,  # interval at which the eviction policy is triggered\n",
        "    eviction_policy=DistanceLFU_EvictionPolicy(),\n",
        ")\n",
        "mc_modules[\"table_1\"] = MCHManagedCollisionModule(\n",
        "    zch_size=table_1_config.num_embeddings,\n",
        "    input_hash_size=input_hash_size,\n",
        "    device=device,\n",
        "    eviction_interval=1,\n",
        "    eviction_policy=DistanceLFU_EvictionPolicy(),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For embedding tables with managed collision modules, TorchRec provides a wrapper module, `ManagedCollisionEmbeddingCollection`, that contains both the embedding collection and the managed collision modules. Users only need to pass an `EmbeddingCollection` built from their table configurations and a `ManagedCollisionCollection` that wraps the managed collision modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assign only to mc_ec so the ManagedCollisionEmbeddingCollection class name is not overwritten\n",
        "mc_ec = ManagedCollisionEmbeddingCollection(\n",
        "    EmbeddingCollection(\n",
        "        tables=[table_0_config, table_1_config],\n",
        "        device=device,\n",
        "    ),\n",
        "    ManagedCollisionCollection(\n",
        "        managed_collision_modules=mc_modules,\n",
        "        embedding_configs=[table_0_config, table_1_config],\n",
        "    ),\n",
        "    return_remapped_features=True,  # also return the remapped feature IDs\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `ManagedCollisionEmbeddingCollection` module performs both the remapping and the table lookup for its input. Users only need to pass the keyed jagged tensor queries into the module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1363945501556497,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "feature name: feature_0, feature jt: JaggedTensor({\n",
            "    [[1000], [10001]]\n",
            "})\n",
            "\n",
            "feature jt values: tensor([ 1000, 10001])\n",
            "feature name: feature_1, feature jt: JaggedTensor({\n",
            "    [[2000], [20001]]\n",
            "})\n",
            "\n",
            "feature jt values: tensor([ 2000, 20001])\n"
          ]
        }
      ],
      "source": [
        "input_kjt = KeyedJaggedTensor.from_lengths_sync(\n",
        "    keys=[\"feature_0\", \"feature_1\"],\n",
        "    values=torch.tensor([1000, 10001, 2000, 20001]),\n",
        "    lengths=torch.tensor([1, 1, 1, 1]),\n",
        ")\n",
        "for feature_name, feature_jt in input_kjt.to_dict().items():\n",
        "    print(f\"feature name: {feature_name}, feature jt: {feature_jt}\")\n",
        "    print(f\"feature jt values: {feature_jt.values()}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1555795345806679,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "feature name: feature_0, feature embedding: JaggedTensor({\n",
            "    [[[0.022659072652459145, 0.0053002419881522655, -0.025007368996739388, -0.013145492412149906, -0.031139537692070007, -0.01486812811344862, -0.01133741531521082, 0.0027838051319122314, 0.026786740869283676, -0.010626785457134247, 0.01148480549454689, 0.02036162279546261, 0.013492186553776264, -0.024412740021944046, 0.01599711738526821, -0.02390478551387787]], [[-0.029269251972436905, 0.01744556427001953, 0.024260954931378365, 0.029459983110427856, -0.026435773819684982, -0.0034603318199515343, -0.007642757147550583, -0.02111411839723587, 0.027316255494952202, 0.015309474430978298, 0.03137263283133507, 0.01699884422123432, 0.02302604913711548, -0.015266639180481434, -0.019045181572437286, 0.006964980624616146]]]\n",
            "})\n",
            "\n",
            "feature name: feature_1, feature embedding: JaggedTensor({\n",
            "    [[[0.009506281465291977, 0.012826820835471153, -0.0017535268561914563, -0.0009170559933409095, -0.014913717284798622, 0.0040654330514371395, -0.011355634778738022, 0.008443576283752918, 0.0007347835344262421, -0.00907053705304861, 0.010160156525671482, 0.016830360516905785, 0.002154064131900668, -0.010799579322338104, -0.0197420883923769, -0.0025849745143204927]], [[-0.020103629678487778, 0.01041398011147976, -0.01699216105043888, 0.01291638519614935, 0.018798867240548134, 0.01616138033568859, -0.020600538700819016, -0.017695769667625427, 0.0100017711520195, -0.010470695793628693, -0.018935278058052063, -0.011798662133514881, -0.014235826209187508, -0.01985463872551918, 0.009744714014232159, -0.004050525836646557]]]\n",
            "})\n",
            "\n",
            "feature name: feature_0, feature jt values: tensor([997, 998], device='cuda:0')\n",
            "feature name: feature_1, feature jt values: tensor([1997, 1998], device='cuda:0')\n"
          ]
        }
      ],
      "source": [
        "output_embeddings, remapped_ids = mc_ec(input_kjt.to(device))\n",
        "# show output embeddings\n",
        "for feature_name, feature_embedding in output_embeddings.items():\n",
        "    print(f\"feature name: {feature_name}, feature embedding: {feature_embedding}\")\n",
        "# show remapped ids\n",
        "for feature_name, feature_jt in remapped_ids.to_dict().items():\n",
        "    print(f\"feature name: {feature_name}, feature jt values: {feature_jt.values()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we have a basic example of how to use the managed collision modules in TorchRec. \n",
        "\n",
        "We also provide a profiling example to compare the efficiency of sorted ZCH and MPZCH modules. Check the [Readme](Readme.md) file for more details."
      ]
    }
  ],
  "metadata": {
    "fileHeader": "",
    "fileUid": "ee3845a2-a85b-4a8e-8c42-9ce4690c9956",
    "isAdHoc": false,
    "kernelspec": {
      "display_name": "torchrec",
      "language": "python",
      "name": "bento_kernel_torchrec"
    },
    "language_info": {
      "name": "plaintext"
    },
    "orig_nbformat": 4
  }
}
